import taichi as ti

from ..material import Material
from ..utils import Vec3, reflect, sqr, Ray3


@ti.func
def refract(uv: Vec3, norm: Vec3, eta_i_over_eta_t: float) -> Vec3:
    """Refract direction ``uv`` through a surface with normal ``norm``.

    The outgoing direction is built from two pieces:
    - a component perpendicular to the normal, scaled by ``eta_i_over_eta_t``;
    - a component along ``-norm`` whose length completes a unit vector.

    Assumes ``uv`` and ``norm`` are unit length and that ``norm`` faces
    against ``uv`` (the caller in this file flips it so) — TODO confirm for
    other callers. This function does not detect total internal reflection;
    the caller must rule that case out first.
    """
    tangential = uv - uv.dot(norm) * norm
    outPerp = eta_i_over_eta_t * tangential
    # abs() keeps the root real even if |outPerp|^2 drifts slightly above 1.
    outParallelLen = abs(1 - outPerp.dot(outPerp)) ** 0.5
    return outPerp - outParallelLen * norm


@ti.func
def reflectance(cosine: float, ref_idx: float) -> float:
    """Schlick's approximation of the Fresnel reflectance.

    Args:
        cosine: cosine of the angle between the incident ray and the normal.
        ref_idx: refractive index (or ratio of indices). The normal-incidence
            term is identical for ``n`` and ``1 / n``.

    Returns:
        The reflectance; it lies in [r0, 1] for ``cosine`` in [0, 1].
    """
    base = (1 - ref_idx) / (1 + ref_idx)
    normalIncidence = sqr(base)
    grazing = 1 - cosine
    return normalIncidence + (1 - normalIncidence) * grazing ** 5


class Dielectric(Material):
    """Dielectric material that randomly reflects or refracts incoming rays.

    The refractive index is stored in ``attribute[0]`` of this material's
    entry in ``Material.materials``.
    """

    def __init__(self, ratio: float = 1.5):
        """Register the material and store its refractive index ``ratio``."""
        super().__init__()
        Material.materials[self.index].attribute[0] = ratio

    @ti.func
    def scatter(self: int, ray: Ray3, hitPoint: Vec3, norm: Vec3):
        """Scatter ``ray`` at ``hitPoint`` by reflection or refraction.

        Args:
            self: index of this material in ``Material.materials``. Inside
                Taichi scope the material is addressed by index, not by object.
            ray: incoming ray. Its origin and direction are overwritten and the
                updated ray is returned.
            hitPoint: the intersection point; becomes the new ray origin.
            norm: surface normal at ``hitPoint``. Presumably unit length and
                pointing out of the object — TODO confirm against callers.

        Returns:
            ``(color, ray)``: the material color at ``hitPoint`` and the
            scattered ray.
        """
        ray.origin = hitPoint
        color = Material.getColor(self, hitPoint)
        ray.direct = ray.direct.normalized()
        ratio = Material.materials[self].attribute[0]
        if ray.direct.dot(norm) > 0:
            # Ray travels along the normal (exiting, if the normal points
            # outward). Flip the normal to face the ray; eta_i / eta_t is the
            # stored index itself.
            norm *= -1
        else:
            # Ray travels against the normal (entering): eta_i / eta_t = 1 / index.
            ratio = 1 / ratio
        # Clamp to 1: rounding can push the dot product of unit vectors just
        # above 1. That would make 1 - cos^2 negative and sinTheta NaN, and the
        # comparisons below would then depend on NaN semantics (undefined
        # under fast-math).
        cosTheta = ti.min(-ray.direct.dot(norm), 1.0)
        sinTheta = (1 - cosTheta * cosTheta) ** 0.5
        # Reflect when Snell's law has no solution (total internal reflection),
        # otherwise with the Schlick-approximated Fresnel probability.
        if ratio * sinTheta > 1 or ti.random() < reflectance(cosTheta, ratio):
            ray.direct = reflect(ray.direct, norm)
        else:
            ray.direct = refract(ray.direct, norm, ratio)
        return color, ray
